#include "kernel_operator.h"
// Number of buffers allocated per TQue (also the queue depth template argument).
// With 2, each queue holds two tiles so data movement for one tile can overlap
// compute on another (intended double buffering). Each core's slice is also
// split into tileNum * BUFFER_NUM tiles, so this value affects the tile size.
constexpr int32_t BUFFER_NUM = 2;

/**
 * Element-wise vector addition z = x + y executed on Ascend AI cores.
 *
 * The global input of totalLength elements is split evenly across all cores
 * (one slice of blockLength elements per block index). Each core then walks its
 * slice in tileNum * BUFFER_NUM tiles of tileLength elements through a
 * CopyIn -> Compute -> CopyOut pipeline backed by BUFFER_NUM-deep queues.
 *
 * Preconditions (enforced by the host-side tiling, not checked here):
 *   - totalLength is divisible by the block count, and blockLength is divisible
 *     by tileNum * BUFFER_NUM. Any remainder elements are silently NOT processed.
 *   - tileLength * sizeof(DTYPE_*) meets DataCopy's alignment requirement
 *     (NOTE(review): 32 bytes on current SoCs — confirm for the target chip).
 *
 * Degenerate tiling (tileNum == 0, or a slice too small to give every tile at
 * least one element) turns the kernel into a no-op instead of dividing by zero
 * or creating zero-sized queue buffers.
 */
class KernelAddVector {
    public:
        __aicore__ inline KernelAddVector() {}

        /**
         * Binds this core's slice of x, y, z and allocates the local queue buffers.
         * @param x, y        global-memory inputs of totalLength elements each
         * @param z           global-memory output of totalLength elements
         * @param totalLength element count across all cores
         * @param tileNum     tiles per core per buffer (loop runs tileNum * BUFFER_NUM times)
         */
        __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t totalLength, uint32_t tileNum) {
            const uint32_t blockNum = AscendC::GetBlockNum();
            // NOTE(review): this prints once per core; consider removing outside debugging.
            AscendC::printf("Init start: block number=%u\n", blockNum);
            this->blockLength = totalLength / blockNum;
            this->tileNum = tileNum;

            // Guard: tileNum == 0 would divide by zero, and a zero tileLength would
            // create zero-sized queue buffers. Make Process() a no-op instead.
            if (tileNum == 0 || this->blockLength / tileNum / BUFFER_NUM == 0) {
                this->tileNum = 0;
                this->tileLength = 0;
                return;
            }
            this->tileLength = this->blockLength / tileNum / BUFFER_NUM;

            // Start of this core's slice, in elements.
            const int64_t blockOffset = static_cast<int64_t>(this->blockLength) * AscendC::GetBlockIdx();
            xGm.SetGlobalBuffer((__gm__ DTYPE_X*)x + blockOffset, this->blockLength);
            yGm.SetGlobalBuffer((__gm__ DTYPE_Y*)y + blockOffset, this->blockLength);
            zGm.SetGlobalBuffer((__gm__ DTYPE_Z*)z + blockOffset, this->blockLength);

            pipe.InitBuffer(inQueueX, BUFFER_NUM, this->tileLength * sizeof(DTYPE_X));
            pipe.InitBuffer(inQueueY, BUFFER_NUM, this->tileLength * sizeof(DTYPE_Y));
            pipe.InitBuffer(outQueueZ, BUFFER_NUM, this->tileLength * sizeof(DTYPE_Z));
        }

        /** Runs the CopyIn -> Compute -> CopyOut stages once per tile of this core's slice. */
        __aicore__ inline void Process() {
            const int32_t loopCount = this->tileNum * BUFFER_NUM;
            for (int32_t i = 0; i < loopCount; i++) {
                CopyIn(i);
                Compute(i);
                CopyOut(i);
            }
        }

    private:
        /** Copies tile `progress` of x and y from global memory into the input queues. */
        __aicore__ inline void CopyIn(int32_t progress) {
            AscendC::LocalTensor<DTYPE_X> xLocal = inQueueX.AllocTensor<DTYPE_X>();
            AscendC::LocalTensor<DTYPE_Y> yLocal = inQueueY.AllocTensor<DTYPE_Y>();
            AscendC::DataCopy(xLocal, xGm[progress * this->tileLength], this->tileLength);
            AscendC::DataCopy(yLocal, yGm[progress * this->tileLength], this->tileLength);
            inQueueX.EnQue(xLocal);
            inQueueY.EnQue(yLocal);
        }

        /**
         * Dequeues one x/y tile, writes x + y into a fresh output tile, and enqueues it.
         * `progress` is unused; it is kept for symmetry with CopyIn/CopyOut.
         */
        __aicore__ inline void Compute(int32_t progress) {
            AscendC::LocalTensor<DTYPE_X> xLocal = inQueueX.DeQue<DTYPE_X>();
            AscendC::LocalTensor<DTYPE_Y> yLocal = inQueueY.DeQue<DTYPE_Y>();
            AscendC::LocalTensor<DTYPE_Z> zLocal = outQueueZ.AllocTensor<DTYPE_Z>();
            AscendC::Add(zLocal, xLocal, yLocal, this->tileLength);
            outQueueZ.EnQue<DTYPE_Z>(zLocal);
            inQueueX.FreeTensor(xLocal);
            inQueueY.FreeTensor(yLocal);
        }

        /** Writes the computed tile `progress` back to z in global memory and frees it. */
        __aicore__ inline void CopyOut(int32_t progress) {
            AscendC::LocalTensor<DTYPE_Z> zLocal = outQueueZ.DeQue<DTYPE_Z>();
            AscendC::DataCopy(zGm[progress * this->tileLength], zLocal, this->tileLength);
            outQueueZ.FreeTensor(zLocal);
        }

        AscendC::TPipe pipe;
        AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueX, inQueueY;
        AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> outQueueZ;
        AscendC::GlobalTensor<DTYPE_X> xGm;
        AscendC::GlobalTensor<DTYPE_Y> yGm;
        AscendC::GlobalTensor<DTYPE_Z> zGm;
        uint32_t blockLength = 0;  // elements handled by this core
        uint32_t tileNum = 0;      // tiles per buffer; 0 means "nothing to do"
        uint32_t tileLength = 0;   // elements per tile

};

/**
 * Device entry point for the add_vector operator: z = x + y.
 * Unpacks the host-provided tiling (totalLength, tileNum) and runs the
 * per-core KernelAddVector pipeline. The workspace pointer is not used.
 */
extern "C" __global__ __aicore__ void add_vector(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling) {
    (void)workspace;
    GET_TILING_DATA(tilingData, tiling);
    KernelAddVector kernel;
    kernel.Init(x, y, z, tilingData.totalLength, tilingData.tileNum);
    kernel.Process();
}